package mysql

import (
	"context"
	"fmt"
	_ "github.com/go-sql-driver/mysql"
	"github.com/spf13/viper"
	"go-beego-api/component/ck_log"
	"go.uber.org/zap"
	"regexp"
	"strings"
	"sync"
	"xorm.io/xorm"
	"xorm.io/xorm/contexts"
	"xorm.io/xorm/names"
)

// Options groups the connection settings for a single MySQL node.
// NOTE(review): Dns presumably holds a DSN string; this type is not referenced
// by the code in this file (InitMySQL reads viper keys directly) — confirm use.
type Options struct {
	Dns                      string
	MaxIdleConn, MaxOpenConn int
}

var (
	// sessionList maps a node name (string) to the *xorm.EngineGroup that
	// InitMySQL created for it; GetSession reads from it.
	sessionList sync.Map

	// columnMapper, when non-nil, is installed on every engine group that
	// InitMySQL creates. Set it through SetColumnMapper before InitMySQL runs.
	columnMapper names.Mapper
)

// InitMySQL creates one xorm engine group per node listed under the "mysql"
// viper key and registers it in sessionList so GetSession can find it.
//
// For every node it reads:
//   - mysql.<node>.dsn          DSN of the master engine
//   - mysql.<node>.slaves       list of replica DSNs (may be empty)
//   - mysql.<node>.pool.idle    max idle connections
//   - mysql.<node>.pool.active  max open connections
//
// The mapper set via SetColumnMapper (if any) and a TracingHook are attached
// to each group. It panics if the master, any replica, or the group cannot be
// created.
func InitMySQL() {
	nodes := viper.GetStringMap("mysql")
	for node := range nodes {
		prefix := "mysql." + node

		// Check the master error right away: err is reused below, so an
		// unchecked value here would be silently overwritten.
		masterEngine, err := xorm.NewEngine("mysql", viper.GetString(prefix+".dsn"))
		if err != nil {
			panic("create master [" + node + "] connection failed: " + err.Error())
		}

		slaves := viper.GetStringSlice(prefix + ".slaves")
		slaveEngines := make([]*xorm.Engine, len(slaves))
		for i, slave := range slaves {
			slaveEngines[i], err = xorm.NewEngine("mysql", slave)
			if err != nil {
				panic("create slave [" + node + "] connection failed:" + err.Error())
			}
		}

		engineGroup, err := xorm.NewEngineGroup(masterEngine, slaveEngines)
		if err != nil {
			panic("mysql node [" + node + "] connection failed: " + err.Error())
		}

		engineGroup.SetMaxIdleConns(viper.GetInt(prefix + ".pool.idle"))
		engineGroup.SetMaxOpenConns(viper.GetInt(prefix + ".pool.active"))

		if columnMapper != nil {
			engineGroup.SetColumnMapper(columnMapper)
		}
		engineGroup.AddHook(NewTracingHook())
		sessionList.Store(node, engineGroup)
	}
}

// GetSession returns the engine group that InitMySQL registered under the node
// name given as the first argument, or under "default" when no argument is
// passed. Any further arguments are ignored. It returns nil when no group is
// registered under that name.
func GetSession(args ...string) *xorm.EngineGroup {
	key := "default"
	if len(args) != 0 {
		key = args[0]
	}
	stored, found := sessionList.Load(key)
	if !found {
		return nil
	}
	return stored.(*xorm.EngineGroup)
}

// SetColumnMapper sets the column-name mapping strategy that InitMySQL
// installs on the engine groups it creates. The value is only read while
// InitMySQL builds the groups, so it must be set before InitMySQL is called;
// groups that already exist are not changed.
func SetColumnMapper(m names.Mapper) {
	columnMapper = m
}

// TracingHook is an xorm contexts.Hook that delegates to a pair of callbacks;
// NewTracingHook wires in the package's SQL-logging callbacks.
//
// A hook lives as long as the engine it is attached to and is shared by every
// statement the engine runs. Per-statement state (a tracing span, a start
// time, ...) must therefore travel in the context, never in hook fields,
// or concurrent statements would race on it.
//
// The zero value is usable: with no callbacks set it passes the context
// through and does nothing.
type TracingHook struct {
	before func(c *contexts.ContextHook) (context.Context, error)
	after  func(c *contexts.ContextHook) error
}

// BeforeProcess implements contexts.Hook. It runs the configured before
// callback, or returns c.Ctx unchanged when none is set.
func (h *TracingHook) BeforeProcess(c *contexts.ContextHook) (context.Context, error) {
	if h.before == nil {
		return c.Ctx, nil
	}
	return h.before(c)
}

// AfterProcess implements contexts.Hook. It runs the configured after
// callback, or does nothing when none is set.
func (h *TracingHook) AfterProcess(c *contexts.ContextHook) error {
	if h.after == nil {
		return nil
	}
	return h.after(c)
}

// before is the BeforeProcess callback installed by NewTracingHook. It does no
// work of its own and hands the statement's context back unchanged; the
// execution time logged by after comes from c.ExecuteTime instead.
func before(hook *contexts.ContextHook) (context.Context, error) {
	ctx := hook.Ctx
	return ctx, nil
}

// after is the AfterProcess callback installed by NewTracingHook. It writes one
// log entry per statement, carrying the masked SQL text, the masked arguments
// and the execution time in milliseconds:
//   - a failed statement is logged at error level with a StackSkip field;
//   - a SELECT is logged at info level;
//   - any other statement is logged at info level with its affected-row count
//     and last insert id, but only when c.Result is non-nil.
//
// Phone numbers (when the SQL mentions mobile/phone) and e-mail addresses (when
// it mentions mail) are masked in the logged SQL text and string arguments;
// c.SQL and c.Args themselves are not modified. It always returns nil.
func after(c *contexts.ContextHook) error {
	sqlStmt, args := maskSensitive(c.SQL, c.Args)
	l := ck_log.LogCtx(c.Ctx).With(ck_log.PHYLUM, "MySQL", "SQL", sqlStmt, "Args", args, ck_log.EXECUTIONTIME, c.ExecuteTime.Milliseconds())

	if c.Err != nil {
		l.With(zap.StackSkip("StackSkip", 9)).Error(c.Err)
		return nil
	}
	// Classify on the original SQL, tolerating leading whitespace so that raw
	// queries starting with a newline or indentation are still seen as SELECTs.
	if isSelectStatement(c.SQL) {
		l.Info(" ")
		return nil
	}
	if c.Result == nil {
		return nil
	}

	parts := make([]string, 0, 2)
	if affected, err := c.Result.RowsAffected(); err == nil {
		parts = append(parts, fmt.Sprintf("Affected:%d", affected))
	}
	if lastInsertID, err := c.Result.LastInsertId(); err == nil {
		parts = append(parts, fmt.Sprintf("LastInsertId:%d", lastInsertID))
	}
	// The message is passed as an argument, never as the format string.
	// NOTE(review): skipping 50 frames probably leaves the stack field empty — confirm intent.
	l.With(zap.StackSkip("StackSkip", 50)).Infof("%s", strings.Join(parts, " "))
	return nil
}

// maskSensitive returns a masked copy of query and a masked copy of rawArgs.
// Phone numbers are masked when the query mentions "mobile" or "phone", and
// e-mail addresses when it mentions "mail" (case-insensitive). Only string
// arguments are masked; rawArgs itself is left untouched.
func maskSensitive(query string, rawArgs []interface{}) (string, []interface{}) {
	args := make([]interface{}, len(rawArgs))
	copy(args, rawArgs)

	lower := strings.ToLower(query)
	if strings.Contains(lower, "mobile") || strings.Contains(lower, "phone") {
		query = FilterPhone(query)
		maskStringArgs(args, FilterPhone)
	}
	if strings.Contains(lower, "mail") {
		query = FilterMail(query)
		maskStringArgs(args, FilterMail)
	}
	return query, args
}

// maskStringArgs replaces each string element of args, in place, with mask(element).
func maskStringArgs(args []interface{}, mask func(string) string) {
	for i, v := range args {
		if s, ok := v.(string); ok {
			args[i] = mask(s)
		}
	}
}

// isSelectStatement reports whether query begins with the SELECT keyword,
// ignoring case and surrounding whitespace.
func isSelectStatement(query string) bool {
	const keyword = "SELECT"
	q := strings.TrimSpace(query)
	return len(q) >= len(keyword) && strings.EqualFold(q[:len(keyword)], keyword)
}

// NewTracingHook builds the hook InitMySQL attaches to every engine group:
// a TracingHook wired to the package-level before and after callbacks.
func NewTracingHook() *TracingHook {
	hook := new(TracingHook)
	hook.before, hook.after = before, after
	return hook
}

// phonePattern matches an 11-digit number starting with 1 (presumably
// mainland-China mobile numbers). It is compiled once at package init
// because FilterPhone runs on the SQL logging path.
var phonePattern = regexp.MustCompile(`1\d{10}`)

// FilterPhone masks every phone number in src by replacing its middle four
// digits with "****", e.g. "13812345678" becomes "138****5678". All matches
// are masked, not just the first, so no number leaks into logs. Text without
// a match is returned unchanged.
func FilterPhone(src string) string {
	return phonePattern.ReplaceAllStringFunc(src, func(m string) string {
		// m is always exactly 11 ASCII digits, so byte slicing is safe.
		return m[:3] + "****" + m[7:]
	})
}

// mailPattern matches e-mail addresses ending in .com, .cn or .org. The local
// part may contain letters, digits and ._%+-, and the domain may have several
// dot-separated labels, so the whole address is matched rather than only its
// tail (e.g. "john.doe@gmail.com", "a@mail.example.org"). It is compiled once
// at package init because FilterMail runs on the SQL logging path.
var mailPattern = regexp.MustCompile(`[a-zA-Z0-9._%+-]+@[a-zA-Z0-9-]+(?:\.[a-zA-Z0-9-]+)*\.(?:com|cn|org)`)

// FilterMail replaces every e-mail address matched by mailPattern in src with
// "****@****". All matches are masked, not just the first. Text without a
// match is returned unchanged.
func FilterMail(src string) string {
	return mailPattern.ReplaceAllLiteralString(src, "****@****")
}

// Compile-time assertion that *TracingHook satisfies xorm's contexts.Hook
// interface, so a signature mismatch is reported at build time.
var _ contexts.Hook = (*TracingHook)(nil)
